{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ~/contracode\n",
    "import numpy as np\n",
    "import pickle\n",
    "import gzip\n",
    "from tqdm.auto import tqdm\n",
    "import pandas as pd\n",
    "import time\n",
    "from typing import Iterable\n",
    "from loguru import logger\n",
    "import multiprocessing as mp\n",
    "# import modin.pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data.dataset import Dataset\n",
    "from transformers import BertTokenizerFast\n",
    "from representjs import DATA_DIR\n",
    "from tqdm import tqdm\n",
    "# import swifter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Shard the train set for tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_df = load_data(\"/data/ajay/contracode/data/hf_data/augmented_pretrain_df.train.pickle.gz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocessing.pool import ThreadPool\n",
    "\n",
    "def split_df(df, save_pattern, num_chunks=160):\n",
    "    # Split data frame into chunks\n",
    "    chunk_size = int(df.shape[0] / num_chunks)\n",
    "    def save_data(data):\n",
    "        chunk_i, start, chunk_size, df, save_pattern = data\n",
    "        save_path = save_pattern.format(chunk_i)\n",
    "        df_subset = df.iloc[start : start + chunk_size]\n",
    "        df_subset.to_pickle(save_path)\n",
    "        print(\"Saved \", save_path)\n",
    "    items = [(i, start, chunk_size, df, save_pattern) for i, start in enumerate(range(0, df.shape[0], chunk_size))]\n",
    "    with ThreadPool(64) as pool:\n",
    "        pool.map(save_data, items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunk_dir = \"/data/ajay/contracode/data/hf_data/train_chunks\"\n",
    "!mkdir -p {chunk_dir}\n",
    "split_df(train_df, chunk_dir + \"/augmented_pretrain_df.{:04d}.train.pickle.gz\", 160)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tqdm.pandas()\n",
    "\n",
    "path = \"/data/ajay/contracode/data/hf_data/train_chunks/augmented_pretrain_df.0000.train.pickle.gz\"\n",
    "\n",
    "def load_tokenizer(path=\"data/vocab/8k_bpe/8k_bpe-vocab.txt\"):\n",
    "    return BertTokenizerFast(path, clean_text=True, lowercase=False, strip_accents=True, unk_token=\"<unk>\")\n",
    "\n",
    "def load_data(path):\n",
    "    return pd.read_pickle(path)\n",
    "\n",
    "tokenizer = load_tokenizer()\n",
    "df_shard = load_data(path)\n",
    "df_shard['toks'] = df_shard['text'].progress_apply(lambda x: np.asarray(tokenizer.encode(x)))\n",
    "df_shard = df_shard[['data_idx', 'toks']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.contrib.concurrent import process_map\n",
    "\n",
    "dfs = []\n",
    "\n",
    "files = []\n",
    "for i in tqdm(list(range(161))):\n",
    "    path = f\"/data/ajay/contracode/data/hf_data/train_chunks_tokenized/augmented_pretrain_tokenized_df.{i:04d}.train.pickle.gz\"\n",
    "    files.append(path)\n",
    "\n",
    "def load_file(fname):\n",
    "    out = pd.read_pickle(fname)\n",
    "    return out\n",
    "    \n",
    "dfs = process_map(load_file, files, max_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df = pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df.info(memory_usage='deep')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df.to_pickle('/data/ajay/contracode/data/hf_data/merged_tok.pickle.gz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Repack data into plain pickle format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyarrow as pa\n",
    "import pyarrow.feather as feather\n",
    "\n",
    "%time test_df = pd.read_pickle('/data/ajay/contracode/data/hf_data/augmented_pretrain_df_tok.test.pickle.gz')\n",
    "%time feather_test_df = pa.Table.from_pandas(test_df)\n",
    "%time feather.write_feather(feather_test_df, '/data/ajay/contracode/data/hf_data/feather_tok/test_lz4.feather', compression='lz4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%time train_df = pd.read_pickle('/data/ajay/contracode/data/hf_data/merged_tok.pickle.gz')\n",
    "%time feather_train_df = pa.Table.from_pandas(train_df)\n",
    "%time feather.write_feather(feather_train_df, '/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather', compression='lz4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunk_size = int(len(train_df) / 10)\n",
    "for i in tqdm(list(range(11))):\n",
    "    %time sampled_train_df = train_df[chunk_size * i : chunk_size * i + chunk_size]\n",
    "    %time feather_train_df = pa.Table.from_pandas(sampled_train_df)\n",
    "    %time feather.write_feather(feather_train_df, f'/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.{i:02d}', compression='lz4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%time feather.write_feather(feather_train_df, '/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather', compression='lz4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int(len(train_df) * 0.1)\n",
    "len(train_df[int(len(train_df) * 0.1):])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(feather_train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%time feather.read_feather('/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.00')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyarrow.feather as feather\n",
    "from tqdm.contrib.concurrent import thread_map\n",
    "\n",
    "files = [f'/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.{i:02d}' for i in range(11)]\n",
    "%time dfs = thread_map(feather.read_feather, files, max_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "files = [f'/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.{i:02d}' for i in range(11)]\n",
    "%time dfs = thread_map(pd.read_feather, files, max_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "from pathlib import Path\n",
    "glob.glob('/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather' + '*')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
